/*
 * Copyright (C) 2007 by
 * 
 * 	Xuan-Hieu Phan
 *	hieuxuan@ecei.tohoku.ac.jp or pxhieu@gmail.com
 * 	Graduate School of Information Sciences
 * 	Tohoku University
 * 
 *  Cam-Tu Nguyen
 *  ncamtu@gmail.com
 *  College of Technology
 *  Vietnam National University, Hanoi
 *
 * JGibbsLDA is a free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published
 * by the Free Software Foundation; either version 2 of the License,
 * or (at your option) any later version.
 *
 * JGibbsLDA is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with JGibbsLDA; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */

package com.ss.language.model.gibblda;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;

/**
 * Estimates an LDA model with collapsed Gibbs sampling: each token's topic is
 * resampled from p(z_i | z_-i, w) computed from the count matrices of
 * {@link Model}.
 * <p>
 * Depending on the options, the estimator either builds a new model
 * ({@code option.est}) or continues from a previously estimated one
 * ({@code option.estc}). {@link #estimate()} then runs the sampler and saves the
 * resulting model.
 */
public class Estimator {

	// output model
	protected Model trnModel;
	// options this estimator was configured with
	LDACmdOption option;

	/**
	 * Creates an estimator, registers {@code option} in
	 * {@code LDACmdOption.curOption} and initialises the training model.
	 * <p>
	 * NOTE(review): the boolean result of {@link #init(LDACmdOption)} is ignored,
	 * so a failed initialisation is not reported to the caller.
	 *
	 * @param option
	 *            estimation options
	 */
	public Estimator(LDACmdOption option) {
		LDACmdOption.curOption.set(option);
		init(option);
	}

	/**
	 * Initialises {@link #trnModel} from the given options.
	 * <ul>
	 * <li>{@code option.est}: builds a new model, writes the word map, and writes
	 * the per-word document statistics file next to it.</li>
	 * <li>{@code option.estc}: loads a previously estimated model.</li>
	 * </ul>
	 *
	 * @param option
	 *            estimation options
	 * @return {@code false} if the model could not be initialised, {@code true}
	 *         otherwise
	 */
	protected boolean init(LDACmdOption option) {
		this.option = option;
		trnModel = new Model();
		if (option.est) {
			if (!trnModel.initNewModel(option))
				return false;
			trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName);
			writeEachwordsEachWord(trnModel.data.docs);
		} else if (option.estc) {
			if (!trnModel.initEstimatedModel(option))
				return false;
		}

		return true;
	}

	/**
	 * Counts how many times each word occurs in each document and writes the
	 * result to {@code <dir>/<wordMapFileName>-statistic.txt}.
	 * <p>
	 * Each non-blank line of the dictionary's word-id file produces one output
	 * line of the form {@code [(docId:count),(docId:count)]}. Documents are
	 * listed in the order of {@code docs}. Words that occur in no document
	 * produce no line. The file is rewritten from scratch, so re-running
	 * estimation does not append duplicate statistics to a previous run's output.
	 * <p>
	 * This is best effort: any failure is printed and swallowed so that model
	 * initialisation can continue.
	 *
	 * @param docs
	 *            the training documents; nothing is written when null or empty
	 */
	private void writeEachwordsEachWord(Document[] docs) {
		if (docs == null || docs.length == 0) {
			return;
		}
		BufferedReader reader = null;
		Writer writer = null;
		try {
			reader = new BufferedReader(new InputStreamReader(new FileInputStream(
					trnModel.data.localDict.getWordIdsFile()), "UTF-8"));

			// Count each document's words once, instead of rescanning every
			// document for every vocabulary entry.
			List<Map<String, Integer>> docCounts = new ArrayList<Map<String, Integer>>(docs.length);
			for (Document doc : docs) {
				docCounts.add(countWords(doc));
			}

			File file = new File(option.dir + File.separator + option.wordMapFileName + "-statistic.txt");
			// openOutputStream creates missing parent directories, as FileUtils.write did.
			writer = new BufferedWriter(new OutputStreamWriter(FileUtils.openOutputStream(file), "UTF-8"));

			for (String line = reader.readLine(); line != null; line = reader.readLine()) {
				String wordId = line.trim();
				if (wordId.isEmpty()) {
					continue;
				}
				String statistic = formatStatisticLine(wordId, docs, docCounts);
				// Save this word's line only if it occurs somewhere.
				if (statistic != null) {
					writer.write(statistic);
					writer.write(IOUtils.LINE_SEPARATOR);
				}
			}
			// Close explicitly so flush failures are reported by the catch below.
			writer.close();
			writer = null;
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			closeQuietly(reader);
			closeQuietly(writer);
		}
	}

	/**
	 * Builds a word -> occurrence-count map for a single document.
	 *
	 * @param doc
	 *            the document; a null or empty word array yields an empty map
	 * @return the occurrence count of every word in {@code doc}
	 */
	private static Map<String, Integer> countWords(Document doc) {
		Map<String, Integer> counts = new HashMap<String, Integer>();
		String[] words = doc.getAllWords();
		if (words != null) {
			for (String w : words) {
				Integer previous = counts.get(w);
				counts.put(w, previous == null ? 1 : previous + 1);
			}
		}
		return counts;
	}

	/**
	 * Formats one statistics line for {@code wordId}.
	 *
	 * @param wordId
	 *            the (trimmed, non-empty) word to report
	 * @param docs
	 *            the documents, used for their ids
	 * @param docCounts
	 *            per-document word counts, parallel to {@code docs}
	 * @return {@code [(docId:count),...]}, or {@code null} if the word occurs in
	 *         no document
	 */
	private static String formatStatisticLine(String wordId, Document[] docs, List<Map<String, Integer>> docCounts) {
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < docs.length; i++) {
			Integer times = docCounts.get(i).get(wordId);
			if (times != null && times > 0) {
				sb.append(sb.length() == 0 ? "[" : ",");
				sb.append('(').append(docs[i].getDocId()).append(':').append(times).append(')');
			}
		}
		if (sb.length() == 0) {
			return null;
		}
		return sb.append(']').toString();
	}

	/**
	 * Closes {@code c} if it is non-null, ignoring any {@link IOException}.
	 *
	 * @param c
	 *            the resource to close; may be null
	 */
	private static void closeQuietly(Closeable c) {
		if (c != null) {
			try {
				c.close();
			} catch (IOException ignored) {
				// Best effort: the primary error, if any, was already reported.
			}
		}
	}

	/**
	 * Runs {@code trnModel.niters} Gibbs sampling sweeps over all tokens.
	 * Numbering starts after {@code trnModel.liter}, so a continued model keeps
	 * counting from where it stopped.
	 * <p>
	 * When {@code option.savestep > 0}, an intermediate model is saved every
	 * {@code savestep} iterations. The final model is always saved as
	 * {@code model-final}.
	 */
	public void estimate() {
		System.out.println("Sampling " + trnModel.niters + " iteration!");

		int lastIter = trnModel.liter;
		// Inclusive bound: iterations lastIter+1 .. lastIter+niters, i.e. exactly
		// niters sweeps (the previous exclusive bound ran only niters-1).
		for (trnModel.liter = lastIter + 1; trnModel.liter <= trnModel.niters + lastIter; trnModel.liter++) {
			System.out.println("Iteration " + trnModel.liter + " ...");

			// for all z_i
			for (int m = 0; m < trnModel.M; m++) {
				for (int n = 0; n < trnModel.data.docs[m].getLength(); n++) {
					// z_i = z[m][n]
					// sample from p(z_i|z_-i, w)
					int topic = sampling(m, n);
					trnModel.z[m].set(n, topic);
				}// end for each word
			}// end for each document

			if (option.savestep > 0) {
				if (trnModel.liter % option.savestep == 0) {
					System.out.println("Saving the model at iteration " + trnModel.liter + " ...");
					computeTheta();
					computePhi();
					trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5));
				}
			}
		}// end iterations

		System.out.println("Gibbs sampling completed!\n");
		System.out.println("Saving the final model!\n");
		computeTheta();
		computePhi();
		// The loop leaves liter one past the last sweep; record the last completed one.
		trnModel.liter--;
		trnModel.saveModel("model-final");
	}

	/**
	 * Resamples the topic of token {@code n} in document {@code m}.
	 * <p>
	 * The token's current assignment is removed from the counts. A new topic is
	 * drawn from the unnormalised conditional
	 * {@code (nw+beta)/(nwsum+V*beta) * (nd+alpha)/(ndsum+K*alpha)}, and the
	 * counts are updated with the new topic.
	 *
	 * @param m
	 *            document number
	 * @param n
	 *            word number
	 * @return topic id in {@code [0, K)}
	 */
	public int sampling(int m, int n) {
		int topic = trnModel.z[m].get(n);
		int w = trnModel.data.docs[m].getWord(n);
		// Word ids >= V have no row in nw (every nw access is guarded this way).
		boolean inVocab = w < trnModel.V;

		// remove z_i from the count variables
		if (inVocab) {
			trnModel.nw[w][topic] -= 1;
		}
		trnModel.nd[m][topic] -= 1;
		trnModel.nwsum[topic] -= 1;
		trnModel.ndsum[m] -= 1;

		double Vbeta = trnModel.V * trnModel.beta;
		double Kalpha = trnModel.K * trnModel.alpha;
		double docDenominator = trnModel.ndsum[m] + Kalpha;

		// Do multinomial sampling via the cumulative method. p[k] is assigned for
		// every k. Previously it was left stale (already cumulated from the last
		// call) for out-of-vocabulary words. Such words are treated as having a
		// zero word-topic count.
		for (int k = 0; k < trnModel.K; k++) {
			double wordCount = inVocab ? trnModel.nw[w][k] : 0;
			trnModel.p[k] = (wordCount + trnModel.beta) / (trnModel.nwsum[k] + Vbeta)
					* (trnModel.nd[m][k] + trnModel.alpha) / docDenominator;
		}

		// cumulate multinomial parameters
		for (int k = 1; k < trnModel.K; k++) {
			trnModel.p[k] += trnModel.p[k - 1];
		}

		// scaled sample because of unnormalized p[]
		double u = Math.random() * trnModel.p[trnModel.K - 1];

		for (topic = 0; topic < trnModel.K; topic++) {
			if (trnModel.p[topic] > u) // sample topic w.r.t distribution p
				break;
		}
		// Floating-point rounding can make u equal p[K-1]; never return K.
		if (topic >= trnModel.K) {
			topic = trnModel.K - 1;
		}

		// add newly estimated z_i to count variables
		if (inVocab) {
			trnModel.nw[w][topic] += 1;
		}
		trnModel.nd[m][topic] += 1;
		trnModel.nwsum[topic] += 1;
		trnModel.ndsum[m] += 1;

		return topic;
	}

	/**
	 * Computes the document-topic distribution:
	 * {@code theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K*alpha)}.
	 */
	public void computeTheta() {
		for (int m = 0; m < trnModel.M; m++) {
			double denominator = trnModel.ndsum[m] + trnModel.K * trnModel.alpha;
			for (int k = 0; k < trnModel.K; k++) {
				trnModel.theta.save(m, k, (trnModel.nd[m][k] + trnModel.alpha) / denominator);
			}
		}
	}

	/**
	 * Computes the topic-word distribution:
	 * {@code phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V*beta)}.
	 */
	public void computePhi() {
		for (int k = 0; k < trnModel.K; k++) {
			double denominator = trnModel.nwsum[k] + trnModel.V * trnModel.beta;
			for (int w = 0; w < trnModel.V; w++) {
				trnModel.phi.save(k, w, (trnModel.nw[w][k] + trnModel.beta) / denominator);
			}
		}
	}
}
